import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def model_converter(model_path='resnet50.pth', onnx_path='resnet50.onnx',
                    input_shape=(1, 3, 96, 96)):
    """Load a fully pickled PyTorch model and export it to ONNX.

    The model is loaded onto the module-level ``device``, switched to
    evaluation mode, and traced with a random dummy input of ``input_shape``.
    The exported graph names its input ``data`` and its output ``fc``.

    Args:
        model_path: Path to a checkpoint holding the complete model object
            (not just a ``state_dict``).
        onnx_path: Destination path of the exported ONNX file.
        input_shape: Shape of the dummy input used for the export.

    Returns:
        None. The ONNX file is written to ``onnx_path``.
    """
    # The checkpoint stores the complete model. ``map_location`` remaps the
    # stored tensors onto ``device`` while loading; without it, a checkpoint
    # saved from a CUDA model fails to load on a CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # checkpoints from a trusted source. Recent PyTorch releases default to
    # weights_only=True, which rejects full-model pickles; pass
    # weights_only=False there if the installed version requires it.
    model = torch.load(model_path, map_location=device).to(device)
    # Evaluation mode so layers such as dropout / batch norm are exported with
    # their inference behaviour.
    model.eval()

    dummy_input = torch.randn(*input_shape, device=device)
    torch.onnx.export(model, dummy_input, onnx_path,
                      export_params=True,
                      verbose=True,
                      input_names=['data'],
                      output_names=['fc'])

